package weka.classifiers.meta;

import java.io.BufferedInputStream;
import java.io.EOFException;
import java.io.File;
import java.io.FileInputStream;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Vector;

import weka.classifiers.Classifier;
import weka.classifiers.RandomizableMultipleClassifiersCombiner;
import weka.classifiers.lazy.kstar.KStarConstants;
import weka.classifiers.rules.ZeroR;
import weka.core.Aggregateable;
import weka.core.Capabilities;
import weka.core.Environment;
import weka.core.EnvironmentHandler;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

/* loaded from: input_file:weka/classifiers/meta/Vote.class */
/**
 * Meta classifier that combines the outputs of several base classifiers using a
 * configurable combination rule (average, product, majority vote, minimum,
 * maximum or median).
 * <p>
 * Two kinds of ensemble members take part:
 * <ul>
 * <li>the classifiers configured on the superclass, which are (re)built from the
 * training data in {@link #buildClassifier(Instances)};</li>
 * <li>"pre-built" classifiers, which are either deserialized from the files given
 * with {@code -P} or added via {@link #addPreBuiltClassifier(Classifier)} /
 * {@link #aggregate(Classifier)}. These are never retrained.</li>
 * </ul>
 * Built members are always consulted before pre-built ones.
 */
public class Vote extends RandomizableMultipleClassifiersCombiner implements TechnicalInformationHandler, EnvironmentHandler, Aggregateable<Classifier> {
    static final long serialVersionUID = -637891196294399624L;

    /** Combination rule: average of the members' probability estimates. */
    public static final int AVERAGE_RULE = 1;
    /** Combination rule: product of the members' probability estimates (nominal class only). */
    public static final int PRODUCT_RULE = 2;
    /** Combination rule: one vote per member for its top class (nominal class only). */
    public static final int MAJORITY_VOTING_RULE = 3;
    /** Combination rule: per-class minimum of the members' probability estimates. */
    public static final int MIN_RULE = 4;
    /** Combination rule: per-class maximum of the members' probability estimates. */
    public static final int MAX_RULE = 5;
    /** Combination rule: median of the members' predictions (numeric class only). */
    public static final int MEDIAN_RULE = 6;

    /** Tags describing the available combination rules, used for {@code -R}. */
    public static final Tag[] TAGS_RULES = {
        new Tag(AVERAGE_RULE, "AVG", "Average of Probabilities"),
        new Tag(PRODUCT_RULE, "PROD", "Product of Probabilities"),
        new Tag(MAJORITY_VOTING_RULE, "MAJ", "Majority Voting"),
        new Tag(MIN_RULE, "MIN", "Minimum Probability"),
        new Tag(MAX_RULE, "MAX", "Maximum Probability"),
        new Tag(MEDIAN_RULE, "MED", "Median")
    };

    /** Random source used to break ties in majority voting; seeded from {@code getSeed()}. */
    protected Random m_Random;
    /** Empty copy of the training header; set by {@link #buildClassifier(Instances)}. */
    protected Instances m_structure;
    /** The selected combination rule (one of the {@code *_RULE} constants). */
    protected int m_CombinationRule = AVERAGE_RULE;
    /** Paths (possibly containing environment variables) of serialized models to load. */
    protected List<String> m_classifiersToLoad = new ArrayList<String>();
    /** Ensemble members that are used as-is and never retrained. */
    protected List<Classifier> m_preBuiltClassifiers = new ArrayList<Classifier>();
    /** Environment used to resolve variables in {@link #m_classifiersToLoad}; not serialized. */
    protected transient Environment m_env = Environment.getSystemWide();

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description, including the technical references
     */
    public String globalInfo() {
        return "Class for combining classifiers. Different combinations of probability estimates for classification are available.\n\nFor more information see:\n\n" + getTechnicalInformation().toString();
    }

    /**
     * Lists the inherited options followed by {@code -P} and {@code -R}.
     *
     * @return an enumeration of {@link Option} objects
     */
    @Override
    public Enumeration listOptions() {
        Vector<Object> result = new Vector<Object>();
        Enumeration<?> inherited = super.listOptions();
        while (inherited.hasMoreElements()) {
            result.addElement(inherited.nextElement());
        }
        result.addElement(new Option("\tFull path to serialized classifier to include.\n\tMay be specified multiple times to include\n\tmultiple serialized classifiers. Note: it does\n\tnot make sense to use pre-built classifiers in\n\ta cross-validation.", "P", 1, "-P <path to serialized classifier>"));
        result.addElement(new Option("\tThe combination rule to use\n\t(default: AVG)", "R", 1, "-R " + Tag.toOptionList(TAGS_RULES)));
        return result.elements();
    }

    /**
     * Returns the current settings: the inherited options, then {@code -R <rule>},
     * then one {@code -P <path>} pair per serialized classifier.
     *
     * @return the option array
     */
    @Override
    public String[] getOptions() {
        Vector<String> result = new Vector<String>();
        String[] inherited = super.getOptions();
        for (int i = 0; i < inherited.length; i++) {
            result.add(inherited[i]);
        }
        result.add("-R");
        result.add("" + getCombinationRule());
        for (String path : m_classifiersToLoad) {
            result.add("-P");
            result.add(path);
        }
        return result.toArray(new String[result.size()]);
    }

    /**
     * Parses {@code -R} (defaulting to AVG) and every {@code -P}, then hands the
     * remaining options to the superclass.
     *
     * @param options the options to parse; consumed entries are removed by {@link Utils#getOption}
     * @throws Exception if an option is invalid
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String rule = Utils.getOption('R', options);
        if (rule.length() != 0) {
            setCombinationRule(new SelectedTag(rule, TAGS_RULES));
        } else {
            setCombinationRule(new SelectedTag(AVERAGE_RULE, TAGS_RULES));
        }

        m_classifiersToLoad.clear();
        String path = Utils.getOption('P', options);
        while (path.length() != 0) {
            m_classifiersToLoad.add(path);
            path = Utils.getOption('P', options);
        }

        super.setOptions(options);
    }

    /**
     * Returns the bibliographic references this classifier is based on.
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.BOOK);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Ludmila I. Kuncheva");
        result.setValue(TechnicalInformation.Field.TITLE, "Combining Pattern Classifiers: Methods and Algorithms");
        result.setValue(TechnicalInformation.Field.YEAR, "2004");
        result.setValue(TechnicalInformation.Field.PUBLISHER, "John Wiley and Sons, Inc.");

        TechnicalInformation article = result.add(TechnicalInformation.Type.ARTICLE);
        article.setValue(TechnicalInformation.Field.AUTHOR, "J. Kittler and M. Hatef and Robert P.W. Duin and J. Matas");
        article.setValue(TechnicalInformation.Field.YEAR, "1998");
        article.setValue(TechnicalInformation.Field.TITLE, "On combining classifiers");
        article.setValue(TechnicalInformation.Field.JOURNAL, "IEEE Transactions on Pattern Analysis and Machine Intelligence");
        article.setValue(TechnicalInformation.Field.VOLUME, "20");
        article.setValue(TechnicalInformation.Field.NUMBER, "3");
        article.setValue(TechnicalInformation.Field.PAGES, "226-239");
        return result;
    }

    /**
     * Returns the capabilities of the ensemble.
     * <p>
     * This is the intersection of the superclass capabilities (the built members)
     * with those of every pre-built member. It is further restricted to nominal
     * classes for PROD/MAJ, and to numeric classes for MED. Serialized models are
     * loaded lazily, so that their capabilities can be taken into account.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();

        if (m_preBuiltClassifiers.size() == 0 && m_classifiersToLoad.size() > 0) {
            try {
                loadClassifiers(null);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }

        if (m_preBuiltClassifiers.size() > 0) {
            int firstToIntersect = 0;
            if (m_Classifiers.length == 0) {
                // no built members: start from the first pre-built member instead
                result = (Capabilities) m_preBuiltClassifiers.get(0).getCapabilities().clone();
                firstToIntersect = 1;
            }
            // every remaining pre-built member must be intersected -- including index 0
            // when built members supplied the starting capabilities
            for (int i = firstToIntersect; i < m_preBuiltClassifiers.size(); i++) {
                result.and(m_preBuiltClassifiers.get(i).getCapabilities());
            }
            for (Capabilities.Capability capability : Capabilities.Capability.valuesCustom()) {
                result.enableDependency(capability);
            }
        }

        if (m_CombinationRule == PRODUCT_RULE || m_CombinationRule == MAJORITY_VOTING_RULE) {
            result.disableAllClasses();
            result.disableAllClassDependencies();
            result.enable(Capabilities.Capability.NOMINAL_CLASS);
            result.enableDependency(Capabilities.Capability.NOMINAL_CLASS);
        } else if (m_CombinationRule == MEDIAN_RULE) {
            result.disableAllClasses();
            result.disableAllClassDependencies();
            result.enable(Capabilities.Capability.NUMERIC_CLASS);
            result.enableDependency(Capabilities.Capability.NUMERIC_CLASS);
        }
        return result;
    }

    /**
     * Builds the ensemble.
     * <p>
     * Serialized models are (re)loaded if any are configured, the data is checked
     * against {@link #getCapabilities()}, and every non-pre-built member is
     * trained on a copy of the data without missing-class instances.
     *
     * @param instances the training data
     * @throws Exception if loading, the capability check, or a member build fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        Instances train = new Instances(instances);
        train.deleteWithMissingClass();
        m_structure = new Instances(train, 0);
        m_Random = new Random(getSeed());

        if (m_classifiersToLoad.size() > 0) {
            m_preBuiltClassifiers.clear();
            loadClassifiers(instances);
            // a lone ZeroR (presumably the combiner's default member) is dropped so it
            // does not take part alongside the pre-built models
            if (m_Classifiers.length == 1 && (m_Classifiers[0] instanceof ZeroR)) {
                m_Classifiers = new Classifier[0];
            }
        }

        getCapabilities().testWithFail(instances);
        for (int i = 0; i < m_Classifiers.length; i++) {
            getClassifier(i).buildClassifier(train);
        }
    }

    /**
     * Deserializes every configured model file and adds it as a pre-built member.
     * <p>
     * Each file must contain a {@link Classifier}, optionally followed by the
     * {@link Instances} header it was trained on.
     *
     * @param data training data whose header must match any stored header, or
     *             {@code null} to skip the check
     * @throws Exception if a path is not a file, does not hold a classifier, or
     *                   has an incompatible header
     */
    private void loadClassifiers(Instances data) throws Exception {
        for (String configuredPath : m_classifiersToLoad) {
            String path = substituteEnvironment(configuredPath);
            File file = new File(path);
            if (!file.isFile()) {
                throw new Exception("\"" + path + "\" does not seem to be a valid file!");
            }

            FileInputStream raw = new FileInputStream(file);
            try {
                ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(raw));
                Object model = in.readObject();
                if (!(model instanceof Classifier)) {
                    throw new Exception("\"" + path + "\" does not contain a classifier!");
                }
                Object header = readOptionalObject(in);
                if ((header instanceof Instances) && data != null && !data.equalHeaders((Instances) header)) {
                    throw new Exception("\"" + path + "\" was trained with data that is of a different structure than the incoming training data");
                }
                if (header == null) {
                    System.out.println("[Vote] warning: no header instances for \"" + path + "\"");
                }
                addPreBuiltClassifier((Classifier) model);
            } finally {
                // closing the file stream also covers a failure inside the ObjectInputStream constructor
                raw.close();
            }
        }
    }

    /**
     * Reads the next object, treating end-of-stream as "absent".
     *
     * @param in the stream to read from
     * @return the object read, or {@code null} if the stream is exhausted
     * @throws Exception if deserialization fails for another reason
     */
    private static Object readOptionalObject(ObjectInputStream in) throws Exception {
        try {
            return in.readObject();
        } catch (EOFException noMoreObjects) {
            // the model was saved without its training header
            return null;
        }
    }

    /**
     * Resolves environment variables in a model path, best effort.
     *
     * @param path the configured path
     * @return the substituted path, or {@code path} unchanged if it has no
     *         variables or substitution fails
     */
    private String substituteEnvironment(String path) {
        if (!Environment.containsEnvVariables(path)) {
            return path;
        }
        // m_env is transient and therefore null after deserialization
        Environment env = (m_env != null) ? m_env : Environment.getSystemWide();
        try {
            return env.substitute(path);
        } catch (Exception ignored) {
            // unresolved variables: keep the path as configured; the file check reports it
            return path;
        }
    }

    /**
     * Adds a classifier that is used as-is and never retrained.
     *
     * @param classifier the already built classifier
     */
    public void addPreBuiltClassifier(Classifier classifier) {
        m_preBuiltClassifiers.add(classifier);
    }

    /**
     * Removes a previously added pre-built classifier.
     *
     * @param classifier the classifier to remove
     */
    public void removePreBuiltClassifier(Classifier classifier) {
        m_preBuiltClassifiers.remove(classifier);
    }

    /**
     * Classifies an instance.
     * <p>
     * For the distribution-based rules:
     * <ul>
     * <li>a nominal class yields the index of the most probable class, or missing
     *     if every probability is 0;</li>
     * <li>a numeric class yields the single combined value;</li>
     * <li>any other class type yields missing.</li>
     * </ul>
     * The median rule delegates to {@link #classifyInstanceMedian(Instance)}.
     *
     * @param instance the instance to classify
     * @return the prediction
     * @throws Exception if a member fails to predict
     */
    @Override
    public double classifyInstance(Instance instance) throws Exception {
        switch (m_CombinationRule) {
            case AVERAGE_RULE:
            case PRODUCT_RULE:
            case MAJORITY_VOTING_RULE:
            case MIN_RULE:
            case MAX_RULE: {
                double[] dist = distributionForInstance(instance);
                if (instance.classAttribute().isNominal()) {
                    int best = Utils.maxIndex(dist);
                    return dist[best] == 0.0 ? Utils.missingValue() : best;
                }
                if (instance.classAttribute().isNumeric()) {
                    return dist[0];
                }
                return Utils.missingValue();
            }
            case MEDIAN_RULE:
                return classifyInstanceMedian(instance);
            default:
                throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!");
        }
    }

    /**
     * Returns the median of all members' predictions.
     * <p>
     * For an even member count this is the mean of the two middle values. Returns
     * 0 when there are no members.
     *
     * @param instance the instance to classify
     * @return the median prediction
     * @throws Exception if a member fails to predict
     */
    protected double classifyInstanceMedian(Instance instance) throws Exception {
        List<Classifier> members = ensembleMembers();
        if (members.isEmpty()) {
            return 0.0d;
        }
        double[] predictions = new double[members.size()];
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = members.get(i).classifyInstance(instance);
        }
        Arrays.sort(predictions);
        int middle = predictions.length / 2;
        if (predictions.length % 2 == 1) {
            return predictions[middle];
        }
        return (predictions[middle - 1] + predictions[middle]) / 2.0d;
    }

    /**
     * Returns the combined class distribution according to the selected rule.
     * <p>
     * Non-numeric class distributions with a positive sum are normalized.
     *
     * @param instance the instance to classify
     * @return the combined distribution
     * @throws Exception if a member fails to predict
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] result;
        switch (m_CombinationRule) {
            case AVERAGE_RULE:
                result = distributionForInstanceAverage(instance);
                break;
            case PRODUCT_RULE:
                result = distributionForInstanceProduct(instance);
                break;
            case MAJORITY_VOTING_RULE:
                result = distributionForInstanceMajorityVoting(instance);
                break;
            case MIN_RULE:
                result = distributionForInstanceMin(instance);
                break;
            case MAX_RULE:
                result = distributionForInstanceMax(instance);
                break;
            case MEDIAN_RULE:
                result = new double[instance.numClasses()];
                result[0] = classifyInstance(instance);
                break;
            default:
                throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!");
        }
        if (!instance.classAttribute().isNumeric() && Utils.sum(result) > 0.0d) {
            Utils.normalize(result);
        }
        return result;
    }

    /**
     * Returns all ensemble members in evaluation order: built members first, then
     * pre-built ones.
     *
     * @return a fresh list of the members
     */
    private List<Classifier> ensembleMembers() {
        List<Classifier> members = new ArrayList<Classifier>(m_Classifiers.length + m_preBuiltClassifiers.size());
        for (int i = 0; i < m_Classifiers.length; i++) {
            members.add(getClassifier(i));
        }
        members.addAll(m_preBuiltClassifiers);
        return members;
    }

    /**
     * Folds the members' distributions element-wise.
     * <p>
     * The fold is a sum (AVERAGE_RULE), product (PRODUCT_RULE), minimum (MIN_RULE)
     * or maximum (MAX_RULE). The first member's distribution is cloned, so no
     * member's returned array is modified.
     *
     * @param instance the instance to classify
     * @param rule     one of AVERAGE_RULE, PRODUCT_RULE, MIN_RULE, MAX_RULE
     * @return the folded, unnormalized distribution
     * @throws Exception if a member fails to predict
     * @throws IllegalStateException if the ensemble is empty or the rule is unsupported
     */
    private double[] combineDistributions(Instance instance, int rule) throws Exception {
        List<Classifier> members = ensembleMembers();
        if (members.isEmpty()) {
            throw new IllegalStateException("Vote: no base classifiers to combine!");
        }
        double[] result = members.get(0).distributionForInstance(instance).clone();
        for (int m = 1; m < members.size(); m++) {
            double[] dist = members.get(m).distributionForInstance(instance);
            for (int j = 0; j < dist.length; j++) {
                switch (rule) {
                    case AVERAGE_RULE:
                        result[j] += dist[j];
                        break;
                    case PRODUCT_RULE:
                        result[j] *= dist[j];
                        break;
                    case MIN_RULE:
                        if (dist[j] < result[j]) {
                            result[j] = dist[j];
                        }
                        break;
                    case MAX_RULE:
                        if (result[j] < dist[j]) {
                            result[j] = dist[j];
                        }
                        break;
                    default:
                        throw new IllegalStateException("Unknown combination rule '" + rule + "'!");
                }
            }
        }
        return result;
    }

    /**
     * Returns the element-wise mean of the members' distributions.
     *
     * @param instance the instance to classify
     * @return the averaged distribution
     * @throws Exception if a member fails to predict
     */
    protected double[] distributionForInstanceAverage(Instance instance) throws Exception {
        double[] result = combineDistributions(instance, AVERAGE_RULE);
        int memberCount = m_Classifiers.length + m_preBuiltClassifiers.size();
        for (int j = 0; j < result.length; j++) {
            result[j] /= memberCount;
        }
        return result;
    }

    /**
     * Returns the element-wise product of the members' distributions.
     *
     * @param instance the instance to classify
     * @return the unnormalized product distribution
     * @throws Exception if a member fails to predict
     */
    protected double[] distributionForInstanceProduct(Instance instance) throws Exception {
        return combineDistributions(instance, PRODUCT_RULE);
    }

    /**
     * Returns a one-hot distribution for the class with the most votes.
     * <p>
     * Each member votes for every class tied for its highest probability. Ties
     * between winning classes are broken with {@link #m_Random}, which is created
     * from {@code getSeed()} if {@link #buildClassifier(Instances)} has not run
     * (e.g. after {@link #aggregate(Classifier)}).
     *
     * @param instance the instance to classify
     * @return an array with 1 at the chosen class index and 0 elsewhere
     * @throws Exception if a member fails to predict
     */
    protected double[] distributionForInstanceMajorityVoting(Instance instance) throws Exception {
        double[] votes = new double[instance.classAttribute().numValues()];
        for (Classifier member : ensembleMembers()) {
            double[] dist = member.distributionForInstance(instance);
            int best = 0;
            for (int j = 1; j < dist.length; j++) {
                if (dist[j] > dist[best]) {
                    best = j;
                }
            }
            for (int j = 0; j < dist.length; j++) {
                if (dist[j] == dist[best]) {
                    votes[j] += 1.0d;
                }
            }
        }

        int winner = 0;
        for (int j = 1; j < votes.length; j++) {
            if (votes[j] > votes[winner]) {
                winner = j;
            }
        }
        List<Integer> tied = new ArrayList<Integer>();
        for (int j = 0; j < votes.length; j++) {
            if (votes[j] == votes[winner]) {
                tied.add(Integer.valueOf(j));
            }
        }

        if (m_Random == null) {
            m_Random = new Random(getSeed());
        }
        int chosen = tied.get(m_Random.nextInt(tied.size())).intValue();
        double[] result = new double[votes.length];
        result[chosen] = 1.0d;
        return result;
    }

    /**
     * Returns the element-wise maximum of the members' distributions.
     *
     * @param instance the instance to classify
     * @return the unnormalized maximum distribution
     * @throws Exception if a member fails to predict
     */
    protected double[] distributionForInstanceMax(Instance instance) throws Exception {
        return combineDistributions(instance, MAX_RULE);
    }

    /**
     * Returns the element-wise minimum of the members' distributions.
     *
     * @param instance the instance to classify
     * @return the unnormalized minimum distribution
     * @throws Exception if a member fails to predict
     */
    protected double[] distributionForInstanceMin(Instance instance) throws Exception {
        return combineDistributions(instance, MIN_RULE);
    }

    /**
     * Returns the tip text for the combinationRule property.
     *
     * @return the tip text
     */
    public String combinationRuleTipText() {
        return "The combination rule used.";
    }

    /**
     * Returns the selected combination rule.
     *
     * @return the rule as a tag over {@link #TAGS_RULES}
     */
    public SelectedTag getCombinationRule() {
        return new SelectedTag(m_CombinationRule, TAGS_RULES);
    }

    /**
     * Sets the combination rule.
     * <p>
     * Tags not drawn from {@link #TAGS_RULES} are ignored.
     *
     * @param newRule the new rule
     */
    public void setCombinationRule(SelectedTag newRule) {
        if (newRule.getTags() == TAGS_RULES) {
            m_CombinationRule = newRule.getSelectedTag().getID();
        }
    }

    /**
     * Returns the tip text for the preBuiltClassifiers property.
     *
     * @return the tip text
     */
    public String preBuiltClassifiersTipText() {
        return "The pre-built serialized classifiers to include. Multiple serialized classifiers can be included alongside those that are built from scratch when this classifier runs. Note that it does not make sense to include pre-built classifiers in a cross-validation since they are static and their models do not change from fold to fold.";
    }

    /**
     * Replaces the list of serialized model files to load.
     *
     * @param files the model files; {@code null} or empty clears the list
     */
    public void setPreBuiltClassifiers(File[] files) {
        m_classifiersToLoad.clear();
        if (files == null) {
            return;
        }
        for (File file : files) {
            m_classifiersToLoad.add(file.toString());
        }
    }

    /**
     * Returns the configured serialized model files.
     *
     * @return one {@link File} per configured path
     */
    public File[] getPreBuiltClassifiers() {
        File[] files = new File[m_classifiersToLoad.size()];
        for (int i = 0; i < files.length; i++) {
            files[i] = new File(m_classifiersToLoad.get(i));
        }
        return files;
    }

    /**
     * Describes the ensemble members and the combination rule.
     *
     * @return a human-readable description
     */
    @Override
    public String toString() {
        if (m_Classifiers == null) {
            return "Vote: No model built yet.";
        }
        StringBuilder text = new StringBuilder("Vote combines the probability distributions of these base learners:\n");
        for (int i = 0; i < m_Classifiers.length; i++) {
            text.append('\t').append(getClassifierSpec(i)).append('\n');
        }
        for (Classifier classifier : m_preBuiltClassifiers) {
            text.append('\t').append(classifier.getClass().getName());
            // not every classifier is an OptionHandler; only print options when available
            if (classifier instanceof OptionHandler) {
                String options = Utils.joinOptions(((OptionHandler) classifier).getOptions());
                if (options.length() > 0) {
                    text.append(' ').append(options);
                }
            }
            text.append('\n');
        }

        String ruleName;
        switch (m_CombinationRule) {
            case AVERAGE_RULE:
                ruleName = "Average of Probabilities";
                break;
            case PRODUCT_RULE:
                ruleName = "Product of Probabilities";
                break;
            case MAJORITY_VOTING_RULE:
                ruleName = "Majority Voting";
                break;
            case MIN_RULE:
                ruleName = "Minimum Probability";
                break;
            case MAX_RULE:
                ruleName = "Maximum Probability";
                break;
            case MEDIAN_RULE:
                ruleName = "Median Probability";
                break;
            default:
                throw new IllegalStateException("Unknown combination rule '" + m_CombinationRule + "'!");
        }
        text.append("using the '").append(ruleName).append("' combination rule \n");
        return text.toString();
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 9785 $");
    }

    /**
     * Sets the environment used to resolve variables in model paths.
     *
     * @param environment the environment to use
     */
    @Override
    public void setEnvironment(Environment environment) {
        m_env = environment;
    }

    /**
     * Adds a classifier to the ensemble as a pre-built member.
     * <p>
     * If this Vote has never been built and still holds only a single ZeroR
     * member, that member is removed first.
     *
     * @param classifier the classifier to add
     * @return this Vote
     * @throws Exception declared by the interface; not thrown here
     */
    @Override
    public Classifier aggregate(Classifier classifier) throws Exception {
        if (m_structure == null && m_Classifiers.length == 1 && (m_Classifiers[0] instanceof ZeroR)) {
            setClassifiers(new Classifier[0]);
        }
        addPreBuiltClassifier(classifier);
        return this;
    }

    /**
     * Completes aggregation; nothing to do because members are combined at prediction time.
     *
     * @throws Exception declared by the interface; not thrown here
     */
    @Override
    public void finalizeAggregation() throws Exception {
    }

    /**
     * Command-line entry point.
     *
     * @param args the command-line options
     */
    public static void main(String[] args) {
        runClassifier(new Vote(), args);
    }
}
